package com.ikingtech.framework.sdk.limit.gateway.filter;

import com.ikingtech.framework.sdk.context.constant.CommonConstants;
import com.ikingtech.framework.sdk.core.response.R;
import com.ikingtech.framework.sdk.limit.model.RateLimitResult;
import com.ikingtech.framework.sdk.limit.config.properties.RateLimitProperties;
import com.ikingtech.framework.sdk.limit.embedded.RateLimiterService;
import com.ikingtech.framework.sdk.utils.Tools;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebExchangeDecorator;
import reactor.core.publisher.Mono;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;

/**
 * Gateway-wide rate limiting filter.
 *
 * <p>For each request (when rate limiting is enabled) the filter:
 * <ol>
 *   <li>lets the request through if it matches an entry of the exclude list;</li>
 *   <li>otherwise, if the include list is empty, limits it using the global {@code count}/{@code time}
 *       settings;</li>
 *   <li>otherwise, if it matches an entry of the include list, limits it using that entry's
 *       {@code count}/{@code time} (falling back to the global values when unset or non-positive);</li>
 *   <li>otherwise lets it through.</li>
 * </ol>
 * When a permit cannot be acquired, the request is answered with HTTP 429 and a JSON error body.
 *
 * @author zhangqiang
 */
@RequiredArgsConstructor
public class GatewayRateLimitFilter implements GlobalFilter, Ordered {

    /** Charset of the JSON error body; JSON text is UTF-8 by specification (RFC 8259). */
    private static final Charset RESPONSE_CHARSET = StandardCharsets.UTF_8;

    private final RateLimiterService rateLimiterService;

    private final RateLimitProperties rateLimitProperties;

    private final AntPathMatcher antPathMatcher = new AntPathMatcher();

    @Value("${spring.application.name}")
    private String name;

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        if (!rateLimitProperties.isEnable()) {
            return chain.filter(exchange);
        }
        ServerHttpRequest request = resolveRequest(exchange);
        String path = request.getURI().getPath();
        String methodValue = request.getMethod().name();

        // Requests matching the exclude list are never rate limited
        // (match() returns null for an empty/absent list).
        if (match(methodValue, path, rateLimitProperties.getExcludeList()) != null) {
            return chain.filter(exchange);
        }

        List<RateLimitProperties.Request> includeList = rateLimitProperties.getIncludeList();
        Integer count;
        Integer time;
        if (Tools.Coll.isBlank(includeList)) {
            // An empty include list means every (non-excluded) request is limited with the global settings.
            count = rateLimitProperties.getCount();
            time = rateLimitProperties.getTime();
        } else {
            RateLimitProperties.Request includeMatch = match(methodValue, path, includeList);
            if (includeMatch == null) {
                // A non-empty include list was configured and this request is not in it: let it through.
                return chain.filter(exchange);
            }
            count = getDefaultValue(includeMatch.getCount(), rateLimitProperties.getCount());
            time = getDefaultValue(includeMatch.getTime(), rateLimitProperties.getTime());
        }

        // NOTE(review): tryAcquire is called synchronously on the calling thread; if the implementation
        // performs blocking I/O this runs on the reactive event loop — confirm.
        RateLimitResult acquireResult = rateLimiterService.tryAcquire(getCacheKey(path), count, time);
        if (Boolean.FALSE.equals(acquireResult.getSuccess())) {
            ServerHttpResponse response = exchange.getResponse();
            return response.writeWith(rateRequest(response));
        }
        return chain.filter(exchange);
    }

    /**
     * Returns the request used for rule matching and key building.
     *
     * <p>If the exchange is a {@link ServerWebExchangeDecorator} (such as one produced by
     * {@code exchange.mutate().build()}), the request of its delegate is used. Otherwise the exchange's own
     * request is used, so that an undecorated exchange is still rate limited rather than silently skipped.
     *
     * @param exchange the current exchange
     * @return the request to evaluate
     */
    private ServerHttpRequest resolveRequest(ServerWebExchange exchange) {
        if (exchange instanceof ServerWebExchangeDecorator decorator) {
            return decorator.getDelegate().getRequest();
        }
        return exchange.getRequest();
    }

    /**
     * Picks the per-rule value when it is set and positive, otherwise the global default.
     *
     * @param propertiesRequestValue value configured on the matched rule, may be {@code null}
     * @param propertiesDefaultValue global fallback value
     * @return the effective value
     */
    private Integer getDefaultValue(Integer propertiesRequestValue, Integer propertiesDefaultValue) {
        if (propertiesRequestValue == null || propertiesRequestValue <= 0) {
            return propertiesDefaultValue;
        }
        return propertiesRequestValue;
    }

    /**
     * Prepares a 429 (Too Many Requests) response and returns its JSON body.
     *
     * @param response the response to configure (status and Content-Type are set on it)
     * @return a single buffer holding the UTF-8 encoded JSON error body
     */
    private Mono<DataBuffer> rateRequest(ServerHttpResponse response) {
        response.setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
        // set (not add) so an earlier Content-Type value does not produce a duplicate header.
        response.getHeaders().set(HttpHeaders.CONTENT_TYPE, "application/json");
        byte[] body = Tools.Json.toJsonStr(R.failed("访问次数过多，请稍后再试")).getBytes(RESPONSE_CHARSET);
        return Mono.just(response.bufferFactory().wrap(body));
    }

    /**
     * Runs directly after the authentication filter.
     */
    @Override
    public int getOrder() {
        return CommonConstants.AUTHENTICATION_FILTER_ORDER + 1;
    }

    /**
     * Finds the first rule whose type equals the HTTP method (case-insensitive) and whose Ant-style URI
     * pattern matches the request path.
     *
     * @param requestMethod HTTP method name of the request
     * @param requestUri    request path
     * @param list          rules to check; {@code null} or empty yields {@code null}
     * @return the first matching rule, or {@code null} if none matches
     */
    private RateLimitProperties.Request match(String requestMethod, String requestUri, List<RateLimitProperties.Request> list) {
        if (Tools.Coll.isBlank(list)) {
            return null;
        }
        for (RateLimitProperties.Request rule : list) {
            if (requestMethod.equalsIgnoreCase(rule.getType()) && rule.getUri() != null
                    && antPathMatcher.match(rule.getUri(), requestUri)) {
                return rule;
            }
        }
        return null;
    }

    /**
     * Builds the limiter key {@code <cacheKeyPrefix>:<applicationName>::<requestPath>}.
     *
     * <p>The key contains neither the HTTP method nor any client identity, so all clients and methods
     * hitting the same path share one bucket. NOTE(review): the empty segment between the two colons
     * looks like a removed component; it is kept as-is so existing keys remain valid.
     *
     * @param requestUri request path
     * @return the key passed to the rate limiter
     */
    private String getCacheKey(String requestUri) {
        return this.rateLimitProperties.getCacheKeyPrefix() +
                Tools.Str.COLON +
                name +
                Tools.Str.COLON +
                Tools.Str.COLON +
                requestUri;
    }
}
